import numpy as np
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import statsmodels.api as sm
import datetime as dt

def weightedBinscatterEvenData(data,xcol,ycol,wcol,nbins):
    """Assign weighted observations to bins of equal total weight, splitting rows as needed.

    Rows are sorted by ``xcol`` and their weights are poured, in order, into
    consecutive bins of capacity ``ceil(total_weight / nbins)``. When a row's
    remaining weight does not fit in the current bin, the part that fits is
    recorded in that bin and the rest carries over to the following bin(s),
    so every bin except possibly the last holds the same weight.

    NOTE: the bin capacity is rounded up with ``np.ceil``. If the weights sum
    to less than ``nbins`` (e.g. normalized shares summing to 1) the capacity
    becomes 1 and far fewer than ``nbins`` bins are produced.

    Parameters
    ----------
    data : pandas.DataFrame
        Input data; it is not modified.
    xcol, ycol, wcol : str
        Names of the x, y and weight columns.
    nbins : int
        Requested number of bins.

    Returns
    -------
    agg : pandas.DataFrame
        Indexed by bin number (``bincol``) with columns ``[xcol, ycol, wcol]``:
        the weighted mean of x, the weighted mean of y, and the total weight
        of each bin.
    res_df : pandas.DataFrame
        One row per (observation, bin) piece: x, y, the weight assigned to
        that bin (``wcol``), ``bincol`` and the intermediate columns
        ``wtdx``, ``wtdy``, ``wtsum``, ``smndx``, ``smndy``.
    """
    tot_wt = data[wcol].sum()
    wt_per_bin = np.ceil(tot_wt / nbins)
    dat = data.sort_values(xcol)

    # Extract the columns once: reading rows via .iloc builds a new Series on
    # every access and dominated the runtime of this loop.
    xs = dat[xcol].to_numpy()
    ys = dat[ycol].to_numpy()
    remaining = dat[wcol].to_numpy(dtype=float, copy=True)

    results = {xcol: [], ycol: [], wcol: [], 'bincol': []}
    curr_bin = 0
    curr_wt = 0
    ctr = 0  # loop iterations spent on the current row (runaway guard)
    t = 0
    n_rows = len(remaining)
    while t < n_rows:
        if remaining[t] == 0:
            # Row fully distributed (or zero weight): advance to the next row.
            t += 1
            ctr = 0
        else:
            dif = wt_per_bin - curr_wt  # room left in the current bin
            wdat = remaining[t]
            if wdat > dif:
                wdat = dif  # only the part that fits goes into this bin
            results[xcol].append(xs[t])
            results[ycol].append(ys[t])
            results[wcol].append(wdat)
            results['bincol'].append(curr_bin)
            curr_wt = curr_wt + wdat
            remaining[t] = remaining[t] - wdat
            if curr_wt == wt_per_bin:
                # Bin is full: start a new one.
                curr_wt = 0
                curr_bin += 1
        ctr += 1
        if ctr > 1000:
            # Safety valve against a row that never finishes (e.g. a NaN weight
            # never reaches 0). A row spanning >1000 bins also trips it and the
            # result is truncated.
            print('>1K for 1 t')
            break

    res_df = pd.DataFrame(results)
    res_df['wtdx'] = res_df[xcol] * res_df[wcol]
    res_df['wtdy'] = res_df[ycol] * res_df[wcol]
    res_df['wtsum'] = res_df.groupby('bincol')[wcol].transform('sum')
    res_df['smndx'] = res_df['wtdx'] / res_df['wtsum']
    res_df['smndy'] = res_df['wtdy'] / res_df['wtsum']
    # String aggregator names avoid the FutureWarning pandas >= 2.1 emits for
    # builtin/NumPy callables such as ``sum`` and ``np.mean``.
    agg = res_df.groupby('bincol').agg({'smndx': 'sum', 'smndy': 'sum', 'wtsum': 'mean'})
    agg.columns = [xcol, ycol, wcol]

    return agg, res_df

def weightedBinscatterEven(data, xcol, ycol, wcol, nbins, mrkr='o', clr='k', style='-', xlabel=None, ylabel=None, title=False):
    """Plot an equal-weight binscatter together with a weighted least-squares line.

    The bins come from ``weightedBinscatterEvenData`` (rows may be split
    across bins). The line is a WLS fit of ``ycol`` on ``xcol`` over the full,
    unbinned ``data``, drawn between the smallest and largest binned x value
    and labelled with its slope.

    Parameters
    ----------
    data : pandas.DataFrame
    xcol, ycol, wcol : str
        x, y and weight column names.
    nbins : int
        Requested number of bins.
    mrkr : str
        Marker; markers in the filled-shape list are drawn hollow.
    clr, style : str
        Colour and line style, concatenated into a matplotlib format string
        for the fit line.
    xlabel, ylabel : str, optional
        Axis labels; default to the capitalized column names.
    title : bool
        Add a title describing the weighting and bin count when ``== True``.

    Returns
    -------
    (fig, scatterdata) where ``scatterdata`` is the ``(agg, res_df)`` tuple
    returned by ``weightedBinscatterEvenData``.
    """
    scatterdata = weightedBinscatterEvenData(data, xcol, ycol, wcol, nbins)
    binned = scatterdata[0]

    design = sm.add_constant(data[[xcol]])
    fit = sm.WLS(data[ycol], design, weights=data[wcol]).fit()
    slope = fit.params[xcol]
    intercept = fit.params['const']
    x_lo = binned[xcol].min()
    x_hi = binned[xcol].max()
    line_x = [x_lo, x_hi]
    line_y = [x_lo * slope + intercept, x_hi * slope + intercept]

    fig, ax = plt.subplots(1, 1)
    hollow_markers = ['o', 'v', '^', '<', '>', '8', 's', 'p', 'P', '*', 'h', 'H', 'd', 'D', 'X']
    if mrkr in hollow_markers:
        ax.scatter(binned[xcol], binned[ycol], marker=mrkr, facecolors='none', edgecolors=clr)
    else:
        ax.scatter(binned[xcol], binned[ycol], marker=mrkr, color=clr)

    ax.set_xlabel(xcol.capitalize() if xlabel is None else xlabel)
    ax.set_ylabel(ycol.capitalize() if ylabel is None else ylabel)
    if title == True:
        ax.set_title('Binscatter weighted by ' + wcol.capitalize() + ' w ' + str(nbins) + ' Bins')

    coef_text = r'$\beta$' + '=' + "{0:.3f}".format(slope)
    ax.plot(line_x, line_y, clr + style, label=coef_text)
    ax.legend()
    return fig, scatterdata

def getBinscatterData(data,x_col,y_col,w_col,nbins, shuff_0=True, best_lin_fit=True):
    """Compute approximately equal-weight binscatter points (rows are never split).

    Rows are sorted by ``x_col`` and assigned whole to bins by cumulative
    weight; each bin holds up to ``ceil(total_weight / nbins)``. A row whose
    cumulative weight lands exactly on a bin boundary stays in the bin it
    completes, so at most ``nbins`` bins are produced.

    NOTE: because the capacity is rounded up with ``np.ceil``, weights summing
    to less than ``nbins`` (e.g. shares summing to 1) give far fewer bins.

    Parameters
    ----------
    data : pandas.DataFrame
        Input data; it is not modified.
    x_col, y_col, w_col : str
        x, y and weight column names.
    nbins : int
        Requested number of bins.
    shuff_0 : bool
        Round x to 2 decimals, then randomly shuffle (fixed seed 42) the rows
        whose rounded x is 0, so that bin membership among those ties does not
        follow the sort's tie order. The shuffled block is kept between the
        negative and positive x values. Only the ordering and the line's x
        endpoints use the rounded x; the bin means use the original x.
    best_lin_fit : bool
        Also fit a WLS regression of y on x over the full ``data``.

    Returns
    -------
    scatterdata : pandas.DataFrame
        Columns ``bin``, ``summandy``, ``summandx`` -- the weighted means of y
        and x in each bin.
    linex, liney : list or None
        Endpoints of the fitted line over the x range, or None.
    params : pandas.Series or None
        WLS parameters (``const`` and ``x_col``), or None.
    """
    dat = data[[x_col, y_col, w_col]].copy()
    dat['wtedx'] = dat[x_col] * dat[w_col]
    dat['wtedy'] = dat[y_col] * dat[w_col]
    total_w = dat[w_col].sum()
    binwt = np.ceil(total_w / nbins)
    dat = dat.sort_values(x_col)
    if shuff_0:
        dat[x_col] = dat[x_col].round(2)
        is_zero = dat[x_col] == 0
        dat_0 = dat.loc[is_zero].sample(frac=1, random_state=42)
        dat_n0 = dat.loc[~is_zero]  # still sorted; NaN x (if any) stays last
        below = dat_n0[x_col] < 0
        # DataFrame.append was removed in pandas 2.0; concat keeps zeros in
        # their sorted position between negative and positive x.
        dat = pd.concat([dat_n0.loc[below], dat_0, dat_n0.loc[~below]])
    dat['cumwt'] = dat[w_col].cumsum()
    # ceil(...) - 1 keeps a row that exactly fills a bin in that bin (floor()
    # would push it to the next one and create an extra, near-empty last bin).
    # Clipping keeps leading zero-weight rows in bin 0 and guards the top end
    # against floating-point drift in the cumulative sum.
    dat['bin'] = (np.ceil(dat['cumwt'] / binwt) - 1).clip(lower=0, upper=nbins - 1)
    dat['binwt'] = dat.groupby('bin')[w_col].transform('sum')
    dat['summandy'] = dat['wtedy'] / dat['binwt']
    dat['summandx'] = dat['wtedx'] / dat['binwt']
    scatterdata = dat.groupby('bin')[['summandy', 'summandx']].sum().reset_index()
    if best_lin_fit:
        X = sm.add_constant(data[[x_col]])
        results = sm.WLS(data[y_col], X, weights=data[w_col]).fit()
        params = results.params
        x_lo = dat[x_col].min()
        x_hi = dat[x_col].max()
        linex = [x_lo, x_hi]
        liney = [x_lo * params[x_col] + params['const'], x_hi * params[x_col] + params['const']]
    else:
        linex = None
        liney = None
        params = None
    return scatterdata, linex, liney, params

def weightedBinscatterMulti(datasets,
                            x_col,
                            y_col,
                            w_col,
                            nbins,
                            xlab=None,
                            ylab=None,
                            datalabels=('',),
                            colors=None,
                            shapes=None,
                            styles=None,
                            xcollist=None,
                            ycollist=None,
                            wcollist=None,
                            title=False,
                            alphas=None,
                            reportCoefs=False,
                            best_lin_fit=False,
                            scale_y=True,
                            add_line_labels=False,
                            line_labels=('',),
                            label_coords=((0, 0),),
                            legend=False,
                            ylims=None):
    """Overlay the binscatters of several datasets on a single figure.

    Each dataset is binned with ``getBinscatterData`` (zero-valued x
    shuffled) and drawn as a scatter. Optionally, a WLS fit line is added per
    dataset.

    Parameters
    ----------
    datasets : sequence of pandas.DataFrame
    x_col, y_col, w_col : str
        Default column names. When given, ``xcollist`` / ``ycollist`` /
        ``wcollist`` override them per dataset. The default axis labels and
        the title use the names in effect for the *last* dataset.
    nbins : int
        Requested number of bins per dataset.
    xlab, ylab : str, optional
        Axis labels overriding the capitalized column names.
    datalabels : sequence of str
        Per-dataset name used in the coefficient label. An empty or missing
        entry falls back to the dataset's index.
    colors, shapes, styles, alphas : sequence, optional
        Per-dataset colour, marker, line style and alpha. Markers from the
        filled-shape list are drawn hollow.
    title : bool
        Add a title describing the weighting and bin count.
    reportCoefs : bool
        Attach "<name>, beta=<slope>" as a legend label. It goes on the fit
        line when ``best_lin_fit`` is set, otherwise on the scatter points.
        It is only visible with ``legend=True``.
    best_lin_fit : bool
        Draw each dataset's WLS fit line.
    scale_y : bool
        Show y tick labels multiplied by 100 (decimals as percentages).
    add_line_labels, line_labels, label_coords
        Free-text annotations drawn at the given data coordinates.
    legend : bool
        Draw a legend.
    ylims : (float, float), optional
        y-axis limits.

    Returns
    -------
    (fig, scatter_data) where ``scatter_data`` is the binned data of the
    *last* dataset, or None if ``datasets`` is empty.
    """
    shapelist = ['o', 'v', '^', '<', '>', '8', 's', 'p', 'P', '*', 'h', 'H', 'd', 'D', 'X']
    n_sets = len(datasets)
    fig, ax = plt.subplots(figsize=(10, 8))
    if colors is None:
        colors = [None] * n_sets
    if shapes is None:
        shapes = [None] * n_sets
    if styles is None:
        styles = [None] * n_sets
    if alphas is None:
        alphas = [1] * n_sets

    # The regression is only needed for the line or the coefficient label.
    need_fit = bool(best_lin_fit or reportCoefs)
    scatter_data = None
    for i, data in enumerate(datasets):
        if xcollist is not None:
            x_col = xcollist[i]
        if ycollist is not None:
            y_col = ycollist[i]
        if wcollist is not None:
            w_col = wcollist[i]
        scatter_data, linex, liney, params = getBinscatterData(
            data, x_col, y_col, w_col, nbins, shuff_0=True, best_lin_fit=need_fit)

        coef_label = None
        if reportCoefs:
            if datalabels and i < len(datalabels) and datalabels[i]:
                name = datalabels[i]
            else:
                name = str(i)
            coef_label = name + ', ' + r'$\beta$' + '=' + "{0:.3f}".format(params[x_col])

        scatter_kw = {'color': colors[i], 'alpha': alphas[i], 'marker': shapes[i]}
        if shapes[i] in shapelist:
            # Hollow markers: transparent face, coloured edge.
            scatter_kw.update(facecolors='none', edgecolors=colors[i])
        if coef_label is not None and not best_lin_fit:
            scatter_kw['label'] = coef_label
        ax.scatter(scatter_data['summandx'], scatter_data['summandy'], **scatter_kw)

        if best_lin_fit:
            line_kw = {'linewidth': 2, 'color': colors[i], 'alpha': alphas[i], 'linestyle': styles[i]}
            if coef_label is not None:
                line_kw['label'] = coef_label
            ax.plot(linex, liney, **line_kw)

    if scale_y:
        # Scale y tick labels up to percentages (instead of decimals).
        pct_unit = 0.01
        ax.yaxis.set_major_formatter(ticker.FuncFormatter(lambda y, pos: '{0:g}'.format(y / pct_unit)))

    ax.set_xlabel(x_col.capitalize() if xlab is None else xlab)
    ax.set_ylabel(y_col.capitalize() if ylab is None else ylab)
    if title:
        ax.set_title('Binscatter weighted by ' + w_col.capitalize() + ' w ' + str(nbins) + ' Bins')
    if legend:
        ax.legend()
    if add_line_labels:
        for i, text in enumerate(line_labels):
            ax.annotate(text, (label_coords[i][0], label_coords[i][1]), size=16)
    if ylims is not None:
        ax.set_ylim(ylims[0], ylims[1])
    return fig, scatter_data

def weightedBinscatterQnD(data, xcol, ycol, wcol, nbins, mrkr='o', clr='k', style='-', xlabel=None, ylabel=None, title=False, shuff_0=True,includeCoef=True,best_lin_fit= False,
                            scale_y=True,
                            set_ymin_zero=False,
                            set_ymax_one=False,
                            add_line_labels=False,
                            line_labels=('',),
                            label_coords=((0, 0),)):
    """A quick and dirty version of the weighted binscatter. *not* guaranteed to have exactly equal bins, though should be close.

    Rows are sorted by ``xcol`` and assigned whole (never split) to bins by
    cumulative weight. Each bin holds up to ``ceil(total_weight / nbins)``,
    and a row landing exactly on a boundary stays in the bin it completes, so
    at most ``nbins`` bins are produced. Because of the ``np.ceil``, weights
    summing to less than ``nbins`` (e.g. shares summing to 1) give far fewer
    bins.

    Parameters
    ----------
    data : pandas.DataFrame
        Input data; it is not modified.
    xcol, ycol, wcol : str
        x, y and weight column names.
    nbins : int
        Requested number of bins.
    mrkr : str
        Marker; markers in the filled-shape list are drawn hollow.
    clr, style : str
        Colour and line style, concatenated into a matplotlib format string
        for the fit line.
    xlabel, ylabel : str, optional
        Axis labels; default to the capitalized column names.
    title : bool
        Add a title describing the weighting and bin count.
    shuff_0 : bool
        Cast x to float and round it to 2 decimals. The rows whose rounded x
        is 0 are then shuffled (fixed seed 42) and kept between the negative
        and positive x values.
    includeCoef : bool
        When drawing the fit line, put the slope in a legend. Otherwise the
        slope is printed.
    best_lin_fit : bool
        Fit and draw a WLS line of y on x over the full ``data``.
    scale_y : bool
        Show y tick labels multiplied by 100 (decimals as percentages).
    set_ymin_zero, set_ymax_one : bool
        Force the y-axis bottom to 0 / top to 1.
    add_line_labels, line_labels, label_coords
        Free-text annotations drawn at the given data coordinates.

    Returns
    -------
    fig : matplotlib Figure
    scatter_data : pandas.DataFrame
        Per-bin column sums, with ``bin`` as a column. ``summandx`` and
        ``summandy`` are the weighted means of x and y in each bin; the other
        columns are plain sums of the working columns.
    """
    x_col, y_col, w_col = xcol, ycol, wcol
    dat = data[[x_col, y_col, w_col]].copy()
    dat['wtedx'] = dat[x_col] * dat[w_col]
    dat['wtedy'] = dat[y_col] * dat[w_col]
    binwt = np.ceil(dat[w_col].sum() / nbins)
    dat = dat.sort_values(x_col)
    if shuff_0:
        dat[x_col] = dat[x_col].astype(float).round(2)
        is_zero = dat[x_col] == 0
        # Fixed seed so repeated calls draw the same figure.
        dat_0 = dat.loc[is_zero].sample(frac=1, random_state=42)
        dat_n0 = dat.loc[~is_zero]  # still sorted; NaN x (if any) stays last
        below = dat_n0[x_col] < 0
        # DataFrame.append was removed in pandas 2.0; concat keeps zeros in
        # their sorted position between negative and positive x.
        dat = pd.concat([dat_n0.loc[below], dat_0, dat_n0.loc[~below]])
    dat['cumwt'] = dat[w_col].cumsum()
    # ceil(...) - 1 keeps a row that exactly fills a bin in that bin, instead
    # of pushing it into an extra, near-empty final bin. Clipping keeps leading
    # zero-weight rows in bin 0 and guards against cumulative-sum drift.
    dat['bin'] = (np.ceil(dat['cumwt'] / binwt) - 1).clip(lower=0, upper=nbins - 1)
    dat['binwt'] = dat.groupby('bin')[w_col].transform('sum')
    dat['summandy'] = dat['wtedy'] / dat['binwt']
    dat['summandx'] = dat['wtedx'] / dat['binwt']
    scatter_data = dat.groupby('bin').sum().reset_index()

    if best_lin_fit:
        X = sm.add_constant(data[[x_col]])
        results = sm.WLS(data[y_col], X, weights=data[w_col]).fit()
        params = results.params
        x_lo = dat[x_col].min()
        x_hi = dat[x_col].max()
        linex = [x_lo, x_hi]
        liney = [x_lo * params[x_col] + params['const'], x_hi * params[x_col] + params['const']]

    fig, ax = plt.subplots(1, 1)
    shapelist = ['o', 'v', '^', '<', '>', '8', 's', 'p', 'P', '*', 'h', 'H', 'd', 'D', 'X']
    if mrkr in shapelist:
        ax.scatter(scatter_data['summandx'], scatter_data['summandy'], marker=mrkr, facecolors='none', edgecolors=clr)
    else:
        ax.scatter(scatter_data['summandx'], scatter_data['summandy'], marker=mrkr, color=clr)
    ax.set_xlabel(x_col.capitalize() if xlabel is None else xlabel)
    ax.set_ylabel(y_col.capitalize() if ylabel is None else ylabel)
    if title:
        ax.set_title('Binscatter weighted by ' + w_col.capitalize() + ' w ' + str(nbins) + ' Bins')
    if best_lin_fit:
        if includeCoef:
            ax.plot(linex, liney, clr + style, label=r'$\beta$' + '=' + "{0:.3f}".format(params[x_col]))
            ax.legend()
        else:
            print(params[x_col])
            ax.plot(linex, liney, clr + style)
    if scale_y:
        # Scale y tick labels up to percentages (instead of decimals).
        pct_unit = 0.01
        ax.yaxis.set_major_formatter(ticker.FuncFormatter(lambda y, pos: '{0:g}'.format(y / pct_unit)))
    if set_ymin_zero:
        ax.set_ylim(bottom=0)
    if set_ymax_one:
        ax.set_ylim(top=1)
    if add_line_labels:
        for i, text in enumerate(line_labels):
            ax.annotate(text, (label_coords[i][0], label_coords[i][1]), size=16)
    return fig, scatter_data
